# Python package dependencies, written as pip-style requirement specifiers.
# Every entry uses `>=`, i.e. it sets only a minimum version; no upper bounds
# or exact pins are given, so an installer may select any newer release.
# NOTE(review): with lower bounds only, installs are not reproducible across
# time — confirm whether a separate lock/constraints file is expected.
tensorflow>=2.3.1
numpy>=1.16.4
# jax and jaxlib are listed with the same minimum version (0.4.4).
# NOTE(review): presumably these two floors should be raised together — confirm.
jax>=0.4.4
jaxlib>=0.4.4
flax>=0.6.4